import sys
import os
import logging
import re
import json
from typing import Dict, List, Any, Optional, Tuple

# Add the directory two levels above this file's directory (the project
# root) to sys.path so that the `shared` package can be imported.
# This must run before the MySQLUtil import below.
sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', '..'))

# Import the MySQL helper used to run this module's queries.
from shared.utils.MySQLUtil import MySQLUtil

# Module-level logger.
logger = logging.getLogger(__name__)

def get_跨专业升学原因分布(
    project_id: int,
    questionnaire_ids: List[int],
    product_code: Optional[str] = None,
    project_code: Optional[str] = None,
    region_code: Optional[str] = None,
    education: Optional[str] = None
) -> Dict[str, Any]:
    """
    Cross-major further-study reason distribution (跨专业升学原因分布) - metric function.

    ## Description
    For question ``T00000392`` of the first ``GRADUATE_SHORT`` questionnaire
    among ``questionnaire_ids``, compute for each reason option the fraction
    of answer rows (``ans_true = 1``, joined to the project's student table by
    ``target_no``) whose option column ``c1``..``c6`` equals 1.
    The reason options are: personal interest, good job prospects, easier
    admission, less study pressure, adjusted outside preference, and other.

    ## Args
        project_id (int): Project ID, used to look up the project configuration
            in ``client_item`` (client code, year, answer-table shard).
        questionnaire_ids (List[int]): Candidate questionnaire IDs. Must be
            non-empty. Only the first one whose ``dy_target`` is
            ``GRADUATE_SHORT`` is used.
        product_code (Optional[str]): Accepted for interface compatibility;
            not used by this implementation.
        project_code (Optional[str]): Accepted for interface compatibility;
            not used by this implementation.
        region_code (Optional[str]): Accepted for interface compatibility;
            not used by this implementation.
        education (Optional[str]): Optional filter on ``s.education``, e.g.
            本科毕业生, 专科毕业生, 硕士研究生, 博士研究生. A falsy value
            disables the filter. The value is bound as a SQL parameter.

    ## Returns
        Success: ``{"success": True, "message": "ok", "code": 0, "result": [...]}``.
        Each result item is ``{"key": <option label>, "val": <ratio as float>}``,
        in the option order of the question definition (``wt_obj.itemList``).
        Failure inside the computation:
        ``{"success": False, "message": ..., "code": 500, "error": <str>}``.
        The exception is logged with its traceback rather than re-raised.

    ## Example
    ### Input
    ```json
    {
        "project_id": 5895,
        "questionnaire_ids": [11158, 11159]
    }
    ```

    ### Output
    ```json
    {
        "success": true,
        "message": "ok",
        "code": 0,
        "result": [
            {"key": "出于个人兴趣", "val": 0.35},
            {"key": "就业前景好", "val": 0.25},
            {"key": "考取难度低", "val": 0.15},
            {"key": "学习压力小", "val": 0.1},
            {"key": "志愿外被调剂", "val": 0.1},
            {"key": "其他", "val": 0.05}
        ]
    }
    ```
    """
    logger.info(f"开始计算指标: 跨专业升学原因分布, 项目ID: {project_id}")

    try:
        # Guard early: an empty ID list would render "IN ()", which is invalid SQL.
        if not questionnaire_ids:
            raise ValueError("问卷ID集合不能为空")

        db = MySQLUtil()

        # 1. Load the project configuration.
        project_sql = """
        SELECT client_code, item_year, dy_target_items, split_tb_paper
        FROM client_item
        WHERE id = %s
        """
        project_info = db.fetchone(project_sql, (project_id,))
        if not project_info:
            raise ValueError(f"未找到项目ID={project_id}的配置信息")

        client_code = project_info['client_code']
        item_year = project_info['item_year']
        split_tb_paper = project_info['split_tb_paper']

        logger.info(f"项目配置: client_code={client_code}, item_year={item_year}, split_tb_paper={split_tb_paper}")

        # 2. Shard key = client_code with leading letters and then leading zeros
        #    stripped (e.g. "ABC00123" -> "123").
        shard_tb_key = re.sub(r'^[A-Za-z]*0*', '', client_code)
        logger.info(f"计算得到 shard_tb_key: {shard_tb_key}")

        # 3. Load questionnaire metadata (one bound placeholder per ID).
        placeholders = ','.join(['%s'] * len(questionnaire_ids))
        questionnaire_sql = f"""
        SELECT id, dy_target
        FROM wt_template_customer
        WHERE id IN ({placeholders})
        """
        questionnaires = db.fetchall(questionnaire_sql, tuple(questionnaire_ids))
        if not questionnaires:
            raise ValueError(f"未找到问卷ID集合={questionnaire_ids}的配置信息")

        logger.info(f"查询到问卷信息: {questionnaires}")

        # 4. Keep only questionnaires aimed at the GRADUATE_SHORT survey target.
        valid_questionnaire_ids = [q['id'] for q in questionnaires if q['dy_target'] == 'GRADUATE_SHORT']
        if not valid_questionnaire_ids:
            raise ValueError("未找到目标调研对象的问卷ID")

        logger.info(f"找到有效问卷ID: {valid_questionnaire_ids}")

        # NOTE(review): only the first matching questionnaire is queried below.
        # Answers belonging to any further GRADUATE_SHORT questionnaires are ignored.
        template_id = valid_questionnaire_ids[0]

        # 5. Locate question T00000392 in that questionnaire.
        question_sql = """
        SELECT id, wt_code, wt_obj
        FROM wt_template_question_customer
        WHERE cd_template_id = %s AND wt_code = 'T00000392' AND is_del = 0
        """
        question_info = db.fetchone(question_sql, (template_id,))
        if not question_info:
            raise ValueError("未找到指定问题编码的问题信息")

        logger.info(f"找到问题信息: {question_info['id']}")

        # 6. Option labels (itemList[*].val) in the question's own order.
        #    Accept either a JSON text payload or an already-decoded object;
        #    which one MySQLUtil returns is not visible here.
        raw_wt_obj = question_info['wt_obj']
        if isinstance(raw_wt_obj, (str, bytes, bytearray)):
            wt_obj = json.loads(raw_wt_obj)
        else:
            wt_obj = raw_wt_obj
        option_labels = [item['val'] for item in wt_obj['itemList']]

        # 7. Physical table names are derived from the project configuration.
        answer_table = f"re_dy_paper_answer_{split_tb_paper}"
        student_table = f"dim_client_target_baseinfo_student_{item_year}"

        params: List[Any] = [template_id, question_info['id'], shard_tb_key, item_year]

        # 8. Optional education filter. The value is bound as a parameter,
        #    never interpolated into the SQL text.
        education_condition = ""
        if education:
            education_condition = "AND s.education = %s"
            params.append(education)

        # 9. Each ratio = rows with the option column set to 1 / all matched rows.
        #    An aggregate SELECT always returns one row. total_count
        #    distinguishes "no answers" (all ratios would be NULL) from real data.
        sql = f"""
        SELECT
            COUNT(*) AS total_count,
            SUM(CASE WHEN t1.c1 = 1 THEN 1 ELSE 0 END)/COUNT(*) as '出于个人兴趣',
            SUM(CASE WHEN t1.c2 = 1 THEN 1 ELSE 0 END)/COUNT(*) as '就业前景好',
            SUM(CASE WHEN t1.c3 = 1 THEN 1 ELSE 0 END)/COUNT(*) as '考取难度低',
            SUM(CASE WHEN t1.c4 = 1 THEN 1 ELSE 0 END)/COUNT(*) as '学习压力小',
            SUM(CASE WHEN t1.c5 = 1 THEN 1 ELSE 0 END)/COUNT(*) as '志愿外被调剂',
            SUM(CASE WHEN t1.c6 = 1 THEN 1 ELSE 0 END)/COUNT(*) as '其他'
        FROM
            {answer_table} t1
            JOIN {student_table} s ON t1.target_no = s.target_no
        WHERE
            t1.cd_template_id = %s
            AND t1.wid = %s
            AND t1.ans_true = 1
            AND s.shard_tb_key = %s
            AND s.item_year = %s
            {education_condition}
        """
        result = db.fetchone(sql, tuple(params))
        if not result or not result.get('total_count'):
            raise ValueError("未找到有效的答案数据")

        # 10. Map each option label to its computed ratio. A label that has no
        #     matching SQL column is reported as 0, and a warning is logged.
        result_data = []
        for label in option_labels:
            val = result.get(label)
            if val is None:
                logger.warning(f"选项 '{label}' 未在统计结果中找到, 按 0 计")
                val = 0
            result_data.append({
                "key": label,
                "val": float(val)
            })

        logger.info("指标 '跨专业升学原因分布' 计算成功")
        return {
            "success": True,
            "message": "ok",
            "code": 0,
            "result": result_data
        }

    except Exception as e:
        # Top-level boundary: log with traceback and report a structured failure.
        logger.error(f"计算指标 '跨专业升学原因分布' 时发生错误: {str(e)}", exc_info=True)
        return {
            "success": False,
            "message": "数据获取失败: 跨专业升学原因分布",
            "code": 500,
            "error": str(e)
        }